package redisDB

import (
	"context"
	"gitee.com/ling-bin/go-utils/idCounter"
	"github.com/go-redis/redis/v8"
	"strconv"
	"time"
)

// LuaString is the Redis Lua script that IsLimit evaluates. It keeps one
// sorted set per limited key; each member is a per-call ID scored by the
// call time in milliseconds.
//
// Inputs:
//   - KEYS[1]: the sorted-set key.
//   - ARGV[1]: current time in ms (score of the new member).
//   - ARGV[2]: TTL in ms applied with PEXPIRE after a member is added.
//   - ARGV[3]: cutoff; members scored in [0, ARGV[3]] are removed first.
//   - ARGV[4]: maximum number of members allowed.
//   - ARGV[5]: member ID for this call.
//
// Reply: 0 when adding this call would exceed ARGV[4] (nothing is added and
// the TTL is left untouched), otherwise the member count after the add.
const LuaString = `local key = KEYS[1];
local id = ARGV[5]
local now = tonumber(ARGV[1]); 
local ttl = tonumber(ARGV[2]); 
local expired = tonumber(ARGV[3]); 
local max = tonumber(ARGV[4]); 
redis.call('zremrangebyscore', key, 0, expired); 
local current = tonumber(redis.call('zcard', key)); 
local next = current + 1; 
if next > max then
	return 0;  
else 
	redis.call('zadd', key, now, id); 
	redis.call('pexpire', key, ttl); 
	return next; 
end
`

// RedisLimit is a rate limiter whose state is kept in Redis.
// Build one with NewRedisLimit.
type RedisLimit struct {
	// store is the Redis client that all limiter commands are sent through.
	store redis.Cmdable
}

// NewRedisLimit creates a RedisLimit that issues its commands through store.
// The client is stored as given; it is not validated (a nil store is accepted).
func NewRedisLimit(store redis.Cmdable) *RedisLimit {
	limiter := new(RedisLimit)
	limiter.store = store
	return limiter
}

// DelLimit removes the rate-limit state stored under key (a Redis DEL).
// It is equivalent to DelLimitContext with context.Background().
//
// It returns (true, nil) whenever the DEL command succeeds, including when
// key did not exist, and (false, err) when Redis reports an error.
func (rl *RedisLimit) DelLimit(key string) (bool, error) {
	return rl.DelLimitContext(context.Background(), key)
}

// DelLimitContext is DelLimit with a caller-supplied context, so the Redis
// round trip honors the caller's deadline and cancellation.
//
// key is the limiter identifier to clear.
func (rl *RedisLimit) DelLimitContext(ctx context.Context, key string) (bool, error) {
	if err := rl.store.Del(ctx, key).Err(); err != nil {
		return false, err
	}
	return true, nil
}

// IsLimit reports whether the rate limit for key has been hit. It returns
// true when this call is rejected (limit reached) and false when the call is
// admitted and recorded. It is equivalent to IsLimitContext with
// context.Background().
//
// key is the identifier being limited (used directly as the Redis key).
// maxLimit is the maximum number of admitted calls within one window.
// timeout is the length of the window.
func (rl *RedisLimit) IsLimit(key string, maxLimit int, timeout time.Duration) (bool, error) {
	return rl.IsLimitContext(context.Background(), key, maxLimit, timeout)
}

// IsLimitContext is IsLimit with a caller-supplied context, so the Redis
// round trip honors the caller's deadline and cancellation.
//
// The check-and-record step runs inside the LuaString script:
//   - entries scored at or before now-timeout are trimmed;
//   - if fewer than maxLimit entries remain, a new entry scored with the
//     current time (ms) is added and the key's TTL is reset to timeout;
//   - rejected calls are not recorded.
//
// NOTE: timeout is truncated to whole milliseconds. A window under 1ms becomes
// 0, so every entry is trimmed and the script issues PEXPIRE with 0. The
// limiter then never triggers unless maxLimit < 1. Timestamps come from this
// process's clock, not from Redis.
func (rl *RedisLimit) IsLimitContext(ctx context.Context, key string, maxLimit int, timeout time.Duration) (bool, error) {
	windowMs := timeout.Milliseconds()
	now := time.Now().UnixMilli()

	// A unique member per call keeps calls that land in the same millisecond
	// from collapsing into one sorted-set entry.
	member := idCounter.NewObjectID().Hex()

	// The ARGV values are passed explicitly (rather than as one []string
	// element) so their order visibly matches the script's ARGV slots.
	reply, err := rl.store.Eval(ctx, LuaString, []string{key},
		strconv.FormatInt(now, 10),          // ARGV[1]: score of the new entry
		strconv.FormatInt(windowMs, 10),     // ARGV[2]: key TTL in ms
		strconv.FormatInt(now-windowMs, 10), // ARGV[3]: trim cutoff (inclusive)
		strconv.Itoa(maxLimit),              // ARGV[4]: max entries per window
		member,                              // ARGV[5]: member ID for this call
	).Int()
	if err != nil {
		return false, err
	}
	// The script replies 0 when the limit is exceeded, otherwise the new count (>= 1).
	return reply == 0, nil
}